<!-- Copyright 2019 Google LLC. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================-->

<html>

<head>
  <title>TensorFlow.js Model Benchmark</title>
  <link href="https://fonts.googleapis.com/css?family=Roboto" rel="stylesheet">
  <link href="./main.css" rel="stylesheet">
  <script src="https://cdnjs.cloudflare.com/ajax/libs/dat-gui/0.7.7/dat.gui.min.js"></script>
  <link rel="shortcut icon" href="https://www.gstatic.com/devrel-devsite/prod/vf4ca28c48392b1412e7b030290622a0dd55b62dec1202c59f119b1e23227c988/tensorflow/images/favicon.png">
</head>

<body>
  <h2>TensorFlow.js Model Benchmark</h2>
  <div id="modal-msg"></div>
  <div id="container">
    <div id="stats">
      <div class="box">
        <pre id="env"></pre>
      </div>
      <table class="table" id="parameters">
        <caption>Parameters</caption>
        <thead>
          <tr>
            <th>Type</th>
            <th>Value</th>
          </tr>
        </thead>
        <tbody>
        </tbody>
      </table>
      <table class="table" id="modelInputs" style="display: none;">
        <caption>Inputs</caption>
        <thead>
          <tr>
            <th>Type</th>
            <th>Value</th>
          </tr>
        </thead>
        <tbody></tbody>
      </table>
      <table class="table" id="timings">
        <caption>Model</caption>
        <thead>
          <tr>
            <th>Type</th>
            <th>Value</th>
          </tr>
        </thead>
        <tbody>
        </tbody>
      </table>
      <div class="box" id="perf-trendline-container">
        <div class="label">Inference times</div>
        <div class="trendline">
          <div class="yMax"></div>
          <div class="yMin"></div>
          <svg>
            <path></path>
          </svg>
        </div>
      </div>
    </div>
    <table class="table" id="kernels">
      <caption>Kernel</caption>
      <thead id="kernels-thead">
        <tr>
          <th>Type</th>
          <th>Time(ms)</th>
        </tr>
      </thead>
      <tbody></tbody>
    </table>
  </div>
  <script src="https://unpkg.com/@tensorflow/tfjs-core@latest/dist/tf-core.js"></script>
  <script src="https://unpkg.com/@tensorflow/tfjs-backend-cpu@latest/dist/tf-backend-cpu.js"></script>
  <script src="https://unpkg.com/@tensorflow/tfjs-backend-webgl@latest/dist/tf-backend-webgl.js"></script>
  <script src="https://unpkg.com/@tensorflow/tfjs-layers@latest/dist/tf-layers.js"></script>
  <script src="https://unpkg.com/@tensorflow/tfjs-converter@latest/dist/tf-converter.js"></script>
  <script src="https://unpkg.com/@tensorflow/tfjs-backend-wasm@latest/dist/tf-backend-wasm.js"></script>
  <script src="https://unpkg.com/@tensorflow/tfjs-automl@latest/dist/tf-automl.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/universal-sentence-encoder"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/speech-commands"> </script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/posenet@2"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/body-pix@2"></script>

  <script src="../model_config.js"></script>
  <script src="../benchmark_util.js"></script>
  <script src="./util.js"></script>
  <script src="./index.js"></script>
  <script>
    'use strict';

    /**
     * Shared page state. The scalar fields and the two action callbacks
     * (`run`, `testCorrectness`) are registered with the dat.GUI panel built
     * in `onPageLoad`; the remaining fields are bookkeeping read and written
     * by the benchmark routines in this script.
     */
    const state = {
      numRuns: 50,
      benchmark: 'mobilenet_v2',
      // Starts a benchmark round. A failure is shown in the modal and then
      // rethrown so it still surfaces in the console.
      run: (v) => {
        runBenchmark().catch(e => {
          showMsg('Error: ' + e.message);
          throw e;
        });
      },
      /**
       * Runs one prediction on the 'cpu' backend and one on `state.backend`
       * and reports in the timings table whether the two results are close.
       * Rejects when the benchmark is 'custom', when loading/inference fails,
       * or when the two predictions do not match.
       */
      testCorrectness: async () => {
        if (state.benchmark === 'custom') {
          // TODO: separate input generation from the predict function
          const unsupportedMsg =
              'Testing correctness for custom models is not supported.';
          await showMsg(unsupportedMsg);
          // Carry the reason in the error so that it is not an empty Error.
          throw new Error(unsupportedMsg);
        }

        let match, predictionData, referenceData;
        await cleanUpTable();

        // load model and run inference
        try {
          await loadModelAndRecordTime();
          await showMsg('Testing correctness');
          await showInputs();
          await showCorrectnessTestParameters();

          // tf.setBackend returns a promise: wait for it so the reference
          // prediction really runs on the CPU backend and a failed switch is
          // caught by the surrounding try/catch.
          await tf.setBackend('cpu');
          await showMsg('Running on cpu');
          const referencePrediction = predict(model);
          referenceData = await getPredictionData(referencePrediction);

          await tf.setBackend(state.backend);
          await showMsg(`Running on ${state.backend}`);
          const prediction = predict(model);
          predictionData = await getPredictionData(prediction);
        } catch (e) {
          showMsg('Error: ' + e.message);
          throw e;
        }

        // compare results
        try {
          await showMsg(null);
          tf.test_util.expectArraysClose(referenceData, predictionData);
          match = true;
        } catch (e) {
          match = false;
          throw e;
        } finally {
          // The outcome is always recorded, even when the comparison throws.
          appendRow(timeTable, `backend`, state.backend);
          appendRow(timeTable, `Prediction matches CPU`, match);
          appendRow(timeTable, '', '');
        }
      },
      backend: 'wasm',
      kernelTiming: 'aggregate',
      modelType: '',
      modelUrl: '',
      isModelChanged: false,
      flags: {},
      isFlagChanged: false,
      isModelLoaded: false,
      inputs: {}
    };

    // Handles to the page elements that the routines below read and update.
    const modalDiv = document.querySelector('#modal-msg');
    const envDiv = document.querySelector('#env');
    const parameterTable = document.querySelector('#parameters tbody');
    const modelInputsTable = document.querySelector('#modelInputs tbody');
    const timeTable = document.querySelector('#timings tbody');
    const kernelsTableHead = document.querySelector('#kernels-thead');
    const kernelTable = document.querySelector('#kernels tbody');

    // `model` and `predict` are assigned by loadModelAndRecordTime();
    // `chartWidth` is assigned by measureAveragePredictTime().
    let model;
    let predict;
    let chartWidth;

    /**
     * Shows `message` followed by '...' in the modal element, or hides the
     * modal when `message` is null/undefined. Resolves after two
     * `tf.nextFrame()` ticks.
     *
     * @param {?string} message Text (inserted as HTML) to show, or null to hide.
     */
    async function showMsg(message) {
      const shouldHide = message == null;
      if (shouldHide) {
        modalDiv.style.display = 'none';
      } else {
        modalDiv.innerHTML = message + '...';
        modalDiv.style.display = 'block';
      }
      for (let frame = 0; frame < 2; frame++) {
        await tf.nextFrame();
      }
    }

    /**
     * Replaces the content of the environment box with a pretty-printed JSON
     * object holding the core, layers and converter version strings of `tf`.
     */
    function showVersions() {
      const versions = {
        core: tf.version_core,
        layers: tf.version_layers,
        converter: tf.version_converter
      };
      const indent = 2;
      envDiv.innerHTML = JSON.stringify(versions, null, indent);
    }

    /**
     * Appends the `tf.env().features` entries, sorted by name, to the
     * environment box as a brace-delimited list with one `name: value` line
     * per feature.
     */
    async function showEnvironment() {
      // Time a tiny add first. NOTE(review): presumably this forces backend
      // initialization so the feature flags are populated - confirm.
      const tinyAdd = () => {
        const lhs = tf.tensor1d([1]);
        const rhs = tf.tensor1d([1]);
        return tf.add(lhs, rhs).data();
      };
      await tf.time(tinyAdd);

      const featureNames = Object.keys(tf.env().features);
      featureNames.sort();
      const lines = [];
      for (const name of featureNames) {
        lines.push(`  ${name}: ${tf.env().features[name]}`);
      }
      envDiv.innerHTML += `<br/>{<br>${lines.join('<br>')}<br>}`;
    }

    /**
     * Fills the parameters table with the settings of a performance benchmark
     * round. The model URL row is only added for the 'custom' benchmark.
     */
    function showBenchmarkingParameters() {
      const entries = [
        ['task', 'Performance Benchmark'],
        ['benchmark', state.benchmark],
        ['modelType', state.modelType || 'Unknown'],
      ];
      if (state.benchmark === 'custom') {
        entries.push(['modelUrl', state.modelUrl]);
      }
      entries.push(
          ['numRuns', state.numRuns],
          ['backend', state.backend],
          ['kernelTiming', state.kernelTiming]);

      for (const [label, value] of entries) {
        appendRow(parameterTable, label, value);
      }
    }

    /**
     * Clears the kernel table and rebuilds its header row. 'individual'
     * kernel timing adds the Inputs/Output columns, and additionally a
     * GPUPrograms column when the backend is 'webgl'. Resolves after one
     * `tf.nextFrame()`.
     */
    async function setupKernelTable() {
      kernelsTableHead.innerText = '';
      kernelTable.innerHTML = '';

      const headerCells = ['<b>Kernel</b>', '<b>Time(ms)</b>'];
      const perKernelDetails = state.kernelTiming === 'individual';
      if (perKernelDetails) {
        headerCells.push('<b>Inputs</b>', '<b>Output</b>');
      }
      if (perKernelDetails && state.backend === 'webgl') {
        headerCells.push('<b>GPUPrograms</b>');
      }
      appendRow(kernelsTableHead, ...headerCells);

      await tf.nextFrame();
    }

    /**
     * Fills the parameters table with the settings of a correctness test
     * (always reported as a single run). The model URL row is only added for
     * the 'custom' benchmark. Resolves after one `tf.nextFrame()`.
     */
    async function showCorrectnessTestParameters() {
      const entries = [
        ['task', 'Correctness Test'],
        ['benchmark', state.benchmark],
        ['modelType', state.modelType || 'Unknown'],
      ];
      if (state.benchmark === 'custom') {
        entries.push(['modelUrl', state.modelUrl]);
      }
      entries.push(['numRuns', 1], ['backend', state.backend]);

      for (const [label, value] of entries) {
        appendRow(parameterTable, label, value);
      }

      await tf.nextFrame();
    }

    /**
     * Appends one `<tr>` to `tbody` with one `<td>` per entry of `cells`.
     * An HTMLElement entry is attached as a child node; any other value is
     * assigned to the cell's `innerHTML`.
     *
     * @param {!Element} tbody Element that receives the new row.
     * @param {...*} cells Cell contents, in column order.
     */
    function appendRow(tbody, ...cells) {
      const row = document.createElement('tr');
      for (const content of cells) {
        const cell = document.createElement('td');
        const isElement = content instanceof HTMLElement;
        if (isElement) {
          cell.appendChild(content);
        } else {
          cell.innerHTML = content;
        }
        row.appendChild(cell);
      }
      tbody.appendChild(row);
    }

    /**
     * Runs a single warm-up inference and appends the active backend and the
     * measured first-inference time to the timings table. For the 'custom'
     * benchmark an input is generated from `state.inputs` and disposed once
     * the inference has finished, whether or not it succeeded.
     */
    async function warmUpAndRecordTime() {
      await showMsg('Warming up');

      const isCustomModel = state.benchmark === 'custom';
      let timeInfo;
      if (!isCustomModel) {
        timeInfo = await timeInference(() => predict(model), 1);
      } else {
        const isGraphModel = model instanceof tf.GraphModel;
        const input = generateInputFromDef(state.inputs, isGraphModel);
        try {
          timeInfo = await timeModelInference(model, input, 1);
        } finally {
          tf.dispose(input);
        }
      }

      await showMsg(null);
      const firstInferenceTime = timeInfo.times[0];
      appendRow(timeTable, 'backend', state.backend);
      appendRow(timeTable, '1st inference', printTime(firstInferenceTime));
    }

    /**
     * Builds the model-inputs table from `state.inputs` (name, shape, dtype
     * and, for the 'custom' benchmark, an editable data range per input).
     * For a 'custom' model, a shape containing -1 is rendered as an editable
     * text field that writes the parsed shape back into `state.inputs`.
     *
     * @return {!Promise<boolean|undefined>} true when the benchmark is
     *     'custom' and at least one input still has a -1 dimension, i.e. the
     *     user has to provide a shape before running; false otherwise.
     *     Resolves to undefined (falsy) when the table was already built and
     *     nothing was done.
     */
    async function showInputs() {
      if (modelInputsTable.innerHTML !== '') {
        // The input table for this model has already been constructed.
        return;
      }
      const isCustomModel = state.benchmark === 'custom';
      // Becomes true as soon as ANY input has a -1 dimension. It must be
      // accumulated across inputs; overwriting it per input would let the
      // last input alone decide the result.
      let variableShape = false;
      const modelInputsElement = document.getElementById('modelInputs');
      if (model['inputs'] == null) {
        modelInputsElement.style.display = 'none';
      } else {
        modelInputsElement.style.display = '';
        // construct the input table for the new model
        state.inputs.forEach((modelInput, inputIndex) => {
          if (modelInput.shape == null) {
            return;
          }
          let shape = modelInput.shape;
          const inputHasVariableDim = modelInput.shape.some(e => e === -1);
          variableShape = variableShape || inputHasVariableDim;
          if (inputHasVariableDim && isCustomModel) {
            const shapeInput = document.createElement("INPUT");
            shapeInput.setAttribute('type', 'text');
            shapeInput.setAttribute('value', shape);
            // Parse the comma-separated field back into a numeric shape
            // when the field loses focus.
            shapeInput.addEventListener('focusout', () => {
              state.inputs[inputIndex].shape =
                  shapeInput.value.split(',').map(x => +x);
            });
            shape = shapeInput;
          }
          appendRow(modelInputsTable, 'name', modelInput.name);
          appendRow(modelInputsTable, 'shape', shape);
          appendRow(modelInputsTable, 'data type', modelInput.dtype);
          // add data range for input generator
          if (isCustomModel) {
            const minInput = createRangeInput(state.inputs[inputIndex].range, 0);
            const maxInput = createRangeInput(state.inputs[inputIndex].range, 1);
            const separator = document.createElement('span');
            separator.textContent = ' - ';
            const div = document.createElement('div');
            div.appendChild(minInput);
            div.appendChild(separator);
            div.appendChild(maxInput);
            appendRow(modelInputsTable, 'data range', div);
          }
          // Blank spacer row between inputs.
          appendRow(modelInputsTable, '', '');
        });
      }

      await tf.nextFrame();
      return isCustomModel && variableShape;
    }

    /**
     * Creates a text input pre-filled with `range[index]`. When the field
     * loses focus its content is converted to a number and written back into
     * `range[index]`.
     *
     * @param {!Array<number>} range Range array that is mutated in place.
     * @param {number} index Slot of `range` that this field edits.
     * @return {!Element} The created input element.
     */
    function createRangeInput(range, index) {
      const field = document.createElement("INPUT");
      field.setAttribute('type', 'text');
      field.setAttribute('value', range[index]);

      const commitValue = () => {
        range[index] = Number(field.value);
      };
      field.addEventListener('focusout', commitValue);
      return field;
    }
    /**
     * Loads the model of the selected benchmark, rebuilds `state.inputs` from
     * the model's declared inputs, obtains the benchmark's predict function
     * and appends the elapsed time to the timings table as 'Model load'.
     *
     * Inputs without a shape are skipped; null dimensions are recorded as -1
     * and every input gets the default data range [0, 1000].
     *
     * @throws {Error} When the selected benchmark has no `load` method.
     */
    async function loadModelAndRecordTime() {
      const benchmark = benchmarks[state.benchmark];
      state.modelType = benchmark['type'];
      // used to clean the performance history
      state.isModelChanged = false;

      if (benchmark['load'] == null) {
        throw new Error(`Please provide a load method for '${state.benchmark}' model.`);
      }

      await showMsg('Loading the model');
      const loadStart = performance.now();
      model = await benchmark.load();

      // construct the input state for the model
      const inputDefs = [];
      state.inputs = inputDefs;
      if (model.inputs) {
        for (let i = 0; i < model.inputs.length; i++) {
          const declared = model.inputs[i];
          if (declared.shape != null) {
            inputDefs.push({
              name: declared.name,
              shape: declared.shape.map(dim => (dim == null ? -1 : dim)),
              dtype: declared.dtype,
              range: [0, 1000]
            });
          }
        }
      }
      predict = benchmark.predictFunc();
      // The measured span covers the load, the input bookkeeping and the
      // predictFunc() call.
      const elapsed = performance.now() - loadStart;

      await showMsg(null);
      appendRow(timeTable, 'Model load', printTime(elapsed));
    }

    // Height, in pixels, given to the trendline SVG.
    const chartHeight = 150;

    /**
     * Draws `data` as a polyline inside the `<svg><path>` found under `node`
     * and writes the axis bounds into the `.yMin` / `.yMax` elements. The
     * SVG is sized to `chartWidth` x `chartHeight`.
     *
     * Degenerate inputs are handled so that the path never contains NaN:
     * empty data clears the chart, a single point is drawn as a horizontal
     * line across the chart, and all-zero data is drawn along the bottom.
     *
     * @param {!Element} node Container holding the svg, path and labels.
     * @param {!Array<number>} data Values to plot, in order.
     * @param {boolean=} forceYMinToZero Start the y axis at 0 instead of at
     *     the smallest value.
     * @param {function(number): *=} yFormatter Formats the axis labels.
     */
    function populateTrendline(node, data, forceYMinToZero = false, yFormatter = d => d) {
      const svg = node.querySelector('svg');
      svg.setAttribute('width', chartWidth);
      svg.setAttribute('height', chartHeight);

      const yMinLabel = node.querySelector('.yMin');
      const yMaxLabel = node.querySelector('.yMax');
      const path = node.querySelector('path');

      if (data.length === 0) {
        // Nothing to plot: Math.max()/Math.min() of nothing are +-Infinity.
        yMinLabel.textContent = '';
        yMaxLabel.textContent = '';
        path.setAttribute('d', '');
        return;
      }

      const yMax = Math.max(...data);
      let yMin = forceYMinToZero ? 0 : Math.min(...data);
      if (yMin === yMax) {
        yMin = 0;
      }

      yMinLabel.textContent = yFormatter(yMin);
      yMaxLabel.textContent = yFormatter(yMax);

      // yRange is 0 only when every value is 0; avoid dividing by it.
      const yRange = yMax - yMin;
      const yOf = d =>
          chartHeight - (yRange === 0 ? 0 : (d - yMin) / yRange) * chartHeight;
      // With one sample there is no spacing to compute (it would be width/0).
      const xIncrement = data.length > 1 ? chartWidth / (data.length - 1) : 0;

      const points = data.map((d, i) => `${i * xIncrement},${yOf(d)}`);
      if (data.length === 1) {
        // Extend the lone sample to the right edge so that it is visible.
        points.push(`${chartWidth},${yOf(data[0])}`);
      }
      path.setAttribute('d', `M${points.join('L')} `);
    }

    /**
     * Runs inference `state.numRuns` times, plots the individual times in the
     * trendline chart (y axis starting at 0) and appends the average and the
     * best time to the timings table, followed by a blank row that closes the
     * benchmark round. Also records the chart container's width in
     * `chartWidth`.
     */
    async function measureAveragePredictTime() {
      await showMsg(`Running predict ${state.numRuns} times`);
      const trendlineContainer =
          document.querySelector('#perf-trendline-container');
      chartWidth = trendlineContainer.getBoundingClientRect().width;

      const numRuns = state.numRuns;
      const isCustomModel = state.benchmark === 'custom';
      let timeInfo;
      if (!isCustomModel) {
        timeInfo = await timeInference(() => predict(model), numRuns);
      } else {
        const isGraphModel = model instanceof tf.GraphModel;
        const input = generateInputFromDef(state.inputs, isGraphModel);
        try {
          timeInfo = await timeModelInference(model, input, numRuns);
        } finally {
          tf.dispose(input);
        }
      }

      const forceInferenceTrendYMinToZero = true;
      populateTrendline(
          trendlineContainer, timeInfo.times, forceInferenceTrendYMinToZero,
          printTime);

      await showMsg(null);
      const averageLabel = `Subsequent average(${state.numRuns} runs)`;
      appendRow(timeTable, averageLabel, printTime(timeInfo.averageTime));
      appendRow(timeTable, 'Best time', printTime(timeInfo.minTime));
      // Add an empty row at the end of a round of benchmark.
      appendRow(timeTable, '', '');
    }

    /**
     * Profiles one inference and appends peak memory, the number of leaked
     * tensors and the elapsed time ('2nd inference') to the timings table.
     * The elapsed time spans the whole profiling step, including input
     * generation for 'custom' models.
     *
     * Kernel times are shown unless the backend is 'webgl' and
     * `queryTimerIsEnabled()` reports false, in which case an explanatory
     * message is displayed instead.
     */
    async function profileMemoryAndKernelTime() {
      await showMsg('Profile memory and kernels');
      const profilingStart = performance.now();

      const isCustomModel = state.benchmark === 'custom';
      let profileInfo;
      if (!isCustomModel) {
        profileInfo = await profileInference(() => predict(model));
      } else {
        const isGraphModel = model instanceof tf.GraphModel;
        const input = generateInputFromDef(state.inputs, isGraphModel);
        try {
          profileInfo = await profileModelInference(model, input);
        } finally {
          tf.dispose(input);
        }
      }

      const elapsed = performance.now() - profilingStart;
      await showMsg(null);
      appendRow(timeTable, 'Peak memory', printMemory(profileInfo.peakBytes));
      appendRow(timeTable, 'Leaked tensors', profileInfo.newTensors);
      appendRow(timeTable, '2nd inference', printTime(elapsed));

      // The timer query is only consulted for the webgl backend.
      const kernelTimesAvailable =
          state.backend !== 'webgl' || queryTimerIsEnabled();
      if (!kernelTimesAvailable) {
        await showMsg('Skipping kernel times since query timer extension is not ' +
          'available. <br/> Please use a browser supporting the EXT_disjoint_timer_query WebGL API.');
        return;
      }
      await showMsg('Profiling kernels');
      appendRow(timeTable, 'Number of kernels', profileInfo.kernels.length);
      showKernelTime(profileInfo);
    }

    /**
     * Appends the profiled kernels to the kernel table. With 'individual'
     * kernel timing each kernel gets its own row (name, time, input shapes,
     * output shapes, extra info); otherwise one row per aggregated kernel
     * (name, time) is added. Times are printed with two decimals.
     *
     * @param {!Object} profileInfo Profile result; `kernels` is read in
     *     'individual' mode and `aggregatedKernels` otherwise.
     */
    function showKernelTime(profileInfo) {
      const tbody = document.querySelector('#kernels tbody');

      // Name cell whose tooltip repeats the (possibly truncated) name.
      const makeNameSpan = (name) => {
        const span = document.createElement('span');
        span.setAttribute('title', name);
        span.textContent = name;
        return span;
      };

      if (state.kernelTiming !== 'individual') {
        profileInfo.aggregatedKernels.forEach(aggregated => {
          appendRow(
              tbody, makeNameSpan(aggregated.name),
              aggregated.timeMs.toFixed(2));
        });
        return;
      }

      // TODO: Add the name for each input node in the returned value of
      // `tf.profile` at https://github.com/tensorflow/tfjs/blob/master/tfjs-core/src/engine.ts#L657.
      const describeInput = (inputShape, index) => {
        return inputShape == null ?
            `input${index}: null` :
            `input${index}: ${inputShape.length}D[${inputShape}]`;
      };

      profileInfo.kernels.forEach(kernel => {
        const shapes = kernel.inputShapes;
        // A kernel without inputs keeps an undefined input description.
        const inputInfo = shapes.length === 0 ?
            undefined :
            shapes.map(describeInput).join('<br>');
        appendRow(
            tbody, makeNameSpan(kernel.name), kernel.kernelTimeMs.toFixed(2),
            inputInfo, kernel.outputShapes, kernel.extraInfo);
      });
    }

    /**
     * Empties the parameters table and, when the selected model has changed,
     * also the timings, kernel and model-inputs tables. Resolves after one
     * `tf.nextFrame()`.
     */
    async function cleanUpTable() {
      const tablesToClear = [parameterTable];
      if (state.isModelChanged) {
        tablesToClear.push(timeTable, kernelTable, modelInputsTable);
      }
      for (const table of tablesToClear) {
        table.innerHTML = '';
      }
      await tf.nextFrame();
    }

    /**
     * (Re)loads the model when the model selection or the flags changed, or
     * when nothing has been loaded yet, and shows its inputs.
     *
     * @return {!Promise<boolean>} false when the user still has to provide
     *     the variable input shapes before the benchmark can run; true when
     *     the benchmark can proceed.
     */
    async function loadModel() {
      const needsReload =
          state.isModelChanged || !state.isModelLoaded || state.isFlagChanged;
      if (!needsReload) {
        return true;
      }

      await loadModelAndRecordTime();
      state.isModelLoaded = true;
      const needInputShape = await showInputs();
      if (needInputShape) {
        // Await the prompt so that it is on screen before the caller goes on
        // and so that a failure is reported to the caller instead of being an
        // unhandled rejection.
        await showMsg('Please set the variable input shapes, update -1 to the desired value.');
        return false;
      }
      return true;
    }

    /**
     * Runs one full benchmark round: applies changed environment flags,
     * resets the tables, loads the model if needed, then performs the
     * warm-up, the memory/kernel profile and the timed runs. Stops right
     * after loading when the model still needs user-provided input shapes.
     */
    async function runBenchmark() {
      const flagsChanged = state.isFlagChanged;
      if (flagsChanged) {
        await setEnvFlags(state.flags);
        // Rebuild the environment box so that it reflects the new flags.
        envDiv.innerHTML = '';
        showVersions();
        await showEnvironment();
      }
      await cleanUpTable();
      await setupKernelTable();

      const canProceed = await loadModel();
      if (state.isFlagChanged) {
        state.isFlagChanged = false;
      }
      // return immediately if load model failed or need user input.
      if (!canProceed) {
        return;
      }
      await showBenchmarkingParameters();

      const measurementSteps = [
        () => warmUpAndRecordTime(),
        () => showMsg('Waiting for GC'),
        () => sleep(1000),
        () => profileMemoryAndKernelTime(),
        () => sleep(200),
        () => measureAveragePredictTime(),
      ];
      for (const step of measurementSteps) {
        await step();
      }
    }


    /**
     * Builds the dat.GUI control panel (benchmark selection, run parameters,
     * backend selection and the two action buttons), activates the initial
     * backend and fills the environment box.
     */
    async function onPageLoad() {
      const gui = new dat.gui.GUI({width: 550});
      gui.domElement.id = 'gui';

      await tf.setBackend(state.backend);

      // ----- Benchmark folder -----
      const modelFolder = gui.addFolder('Benchmark');
      // Controller of the model URL field; only present while the 'custom'
      // benchmark is selected.
      let modelUrlController = null;

      const invalidateLoadedModel = () => {
        state.isModelChanged = true;
        state.isModelLoaded = false;
      };

      const addModelUrlField = () => {
        modelUrlController = modelFolder.add(state, 'modelUrl').onChange(async modelUrl => {
          invalidateLoadedModel();
          await showMsg(null);
        });
        modelUrlController.domElement.querySelector('input').placeholder =
            'https://your-domain.com/model-path/model.json';
      };

      const onBenchmarkSelected = async selected => {
        await showMsg(null);
        invalidateLoadedModel();
        if (selected === 'custom') {
          if (modelUrlController === null) {
            addModelUrlField();
          }
          return;
        }
        if (modelUrlController != null) {
          modelFolder.remove(modelUrlController);
          modelUrlController = null;
        }
      };

      modelFolder.add(state, 'benchmark', Object.keys(benchmarks))
          .name('models')
          .onChange(onBenchmarkSelected);
      modelFolder.open();

      // ----- Parameters folder -----
      const parameterFolder = gui.addFolder('Parameters');
      parameterFolder.add(state, 'numRuns');
      parameterFolder.add(state, 'kernelTiming', ['aggregate', 'individual']);
      parameterFolder.open();

      // ----- Environment folder -----
      const envFolder = gui.addFolder('Environment');
      const onBackendSelected = async backend => {
        // TODO: Failed to set backend
        await tf.setBackend(backend);
        await showFlagSettings(envFolder, state.backend);
      };
      envFolder.add(state, 'backend', ['wasm', 'webgl', 'cpu'])
          .onChange(onBackendSelected);
      envFolder.open();

      // ----- Action buttons -----
      gui.add(state, 'run').name('Run benchmark');
      gui.add(state, 'testCorrectness').name('Test correctness');

      await showFlagSettings(envFolder, state.backend);
      showVersions();
      await showEnvironment();
    }

    onPageLoad();
  </script>
</body>

</html>
